import pickle
import numpy as np
import pandas as pd
import os
from read_data import read_data_from_path
import tensorflow as tf
from tensorflow.keras import layers
from sklearn.preprocessing import MinMaxScaler
from tensorflow.keras.callbacks import ModelCheckpoint
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

# Legend font settings shared by the training-history plots below.
font1 = {
    'family': 'Times New Roman',
    'weight': 'normal',
    'size': 20,
}
# SimHei so CJK glyphs render in figure text; keep the minus sign
# displayable when a non-ASCII font is active.
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False

def build_cnn_model(rate=0.2, initializer='glorot_uniform', in_shape=None):
    """Build and compile the CNN used for energy regression.

    Parameters
    ----------
    rate : float
        Dropout rate applied after every Conv2D/Dense hidden layer.
    initializer : str
        Kernel initializer name used by all trainable layers.
    in_shape : tuple, optional
        Shape of a single input sample, e.g. ``(20, 3, 1)``.  Defaults to
        the module-level ``input_shape[1:]`` set in ``__main__`` (kept for
        backward compatibility with existing callers).

    Returns
    -------
    tf.keras.Sequential
        Compiled model (Adam optimizer, MSE loss, MAE metric).
    """
    if in_shape is None:
        # Backward-compatible fallback to the global set in __main__.
        in_shape = input_shape[1:]

    model = tf.keras.Sequential()
    # Only the first layer needs an input shape; later layers infer it.
    model.add(layers.Conv2D(8, 2, activation='relu', padding='same',
                            input_shape=in_shape,
                            kernel_initializer=initializer))
    model.add(layers.Dropout(rate=rate))

    model.add(layers.Conv2D(16, 2, activation='relu', padding='same',
                            kernel_initializer=initializer))
    model.add(layers.Dropout(rate=rate))

    model.add(layers.Conv2D(8, 2, activation='relu', padding='same',
                            kernel_initializer=initializer))
    model.add(layers.Dropout(rate=rate))

    # Light spatial pooling before flattening into the dense head.
    model.add(layers.AveragePooling2D(pool_size=(2, 2), strides=(1, 1),
                                      padding='valid'))
    model.add(layers.Flatten())

    model.add(layers.Dense(units=32, activation='relu',
                           kernel_initializer=initializer))
    model.add(layers.Dropout(rate=rate))

    model.add(layers.Dense(units=8, activation='relu',
                           kernel_initializer=initializer))
    model.add(layers.Dropout(rate=rate))

    # Linear output unit: the target (scaled energy) is continuous.
    model.add(layers.Dense(units=1, activation='linear'))

    model.compile(
        optimizer=tf.keras.optimizers.Adam(
            learning_rate=0.001, beta_1=0.9, beta_2=0.999,
            epsilon=1e-07, amsgrad=False, name='Adam'),
        loss=tf.keras.losses.MeanSquaredError(),
        metrics=['mae'],
    )

    model.summary()
    return model

def plot_train_test(model, epochs=200, batch_size=200,
                    filepath=r'./model_and_data/cnn_model.h5'):
    """Train ``model`` and plot train/test loss and MAE curves.

    Uses the module-level ``X_train``/``y_train``/``X_test``/``y_test``
    created in ``__main__``.  The weights with the lowest validation MAE
    seen so far are checkpointed to ``filepath``.

    Parameters
    ----------
    model : tf.keras.Model
        A compiled model (e.g. from ``build_cnn_model``).
    epochs : int
        Number of training epochs.
    batch_size : int
        Mini-batch size passed to ``model.fit``.
    filepath : str
        Destination for the best-model checkpoint (HDF5 file).
    """
    # Keep only the best weights by validation MAE (mode='min').
    checkpoint = ModelCheckpoint(filepath=filepath, monitor='val_mae',
                                 verbose=1, save_best_only=True, mode='min')

    history = model.fit(
        X_train,
        y_train,
        batch_size=batch_size,
        epochs=epochs,
        # Validation loss/metrics are evaluated at the end of each epoch.
        validation_data=(X_test, y_test),
        callbacks=[checkpoint],
    )

    plt.figure(figsize=(12, 4))
    plt.subplots_adjust(left=0.125, bottom=None, right=0.9, top=None,
                        wspace=0.3, hspace=None)

    # Left panel: MSE training loss vs. validation loss.
    plt.subplot(1, 2, 1)
    plt.plot(history.history['loss'], linewidth=3, label='Train')
    plt.plot(history.history['val_loss'], linewidth=3,
             linestyle='dashed', label='Test')
    plt.xlabel('Epoch', fontsize=20)
    plt.ylabel('loss', fontsize=20)
    plt.legend(prop=font1)

    # Right panel: the tracked metric is MAE (was mislabeled 'Acc').
    plt.subplot(1, 2, 2)
    plt.plot(history.history['mae'], linewidth=3, label='Train')
    plt.plot(history.history['val_mae'], linewidth=3,
             linestyle='dashed', label='Test')
    plt.xlabel('Epoch', fontsize=20)
    plt.ylabel('MAE', fontsize=20)
    plt.legend(prop=font1)
    plt.show()
    

if __name__ == "__main__":

    # NOTE(security): pickle.load must only be used on trusted local
    # files — never on data from an untrusted source.
    with open('./model_and_data/cdn_labels.pkl', 'rb') as f:
        cdn_data = pickle.load(f)
    with open('./model_and_data/ener_arr.pkl', 'rb') as f:
        ener_arr = pickle.load(f)

    # Targets: one energy value per sample, shaped as a column vector.
    ener_arr = np.array(ener_arr).reshape((-1, 1))
    # Features: drop the last column of the DataFrame, keep the rest.
    cdn_data = cdn_data.iloc[:, :-1].to_numpy()

    # Scale features to [0, 1], then reshape into (samples, 20, 3, 1)
    # image-like tensors for the Conv2D stack.  -1 lets numpy infer the
    # sample count instead of hard-coding 999.
    min_max_scaler = MinMaxScaler()
    cdn_scale = min_max_scaler.fit_transform(cdn_data)
    cdn_scale = cdn_scale.reshape(-1, 20, 3, 1)

    # Targets are transformed with a pre-fitted scaler so predictions
    # can be inverse-transformed consistently elsewhere.
    with open('./model_and_data/scaler.pkl', 'rb') as f:
        standard_scaler = pickle.load(f)
    ener_scale = standard_scaler.transform(ener_arr)

    input_shape = cdn_scale.shape
    print(input_shape)
    print(cdn_scale)

    # Persist the fitted feature scaler for inference-time reuse.
    with open(r'./model_and_data/min_max_scaler.pkl', 'wb') as f:
        pickle.dump(min_max_scaler, f)

    X_train, X_test, y_train, y_test = train_test_split(
        cdn_scale,
        ener_scale,
        test_size=0.3,
        random_state=0)

    model = build_cnn_model()
    # Uncomment to train the model and plot the learning curves:
    # plot_train_test(model)